{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "os.sys.path.append(os.path.dirname(os.path.abspath('.')))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 数据准备"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from datasets.dataset import load_breast_cancer\n",
    "from model_selection.train_test_split import train_test_split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "data = load_breast_cancer()\n",
    "X = data.data\n",
    "Y = data.target\n",
    "\n",
    "X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## KNN\n",
    "首先定义一个距离度量函数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 欧氏距离\n",
    "def E_dist(a:list,b:list):\n",
    "    a=np.array(a)\n",
    "    b=np.array(b)\n",
    "    return np.linalg.norm(a-b)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "KNN模型不需要训练，他把训练数据存储起来，给定一个测试(验证)样本，直接在训练样本库中对比、搜索出最近的邻居所对应的标签。\n",
    "\n",
    "首先定义查找邻居的函数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # 返回单个测试样本的前k个邻居\n",
    "# def get_nb_of_one(X_train, Y_train, x_test, k=3, dist=E_dist):\n",
    "#     dists = []\n",
    "\n",
    "#     for idx in range(len(X_train)):\n",
    "#         cur_dist = dist(x_test, X_train[idx])\n",
    "#         dists.append((Y_train[idx], cur_dist))    # 首位为标签，末位为距离\n",
    "#     dists.sort(key=lambda x: x[1])    # 按照距离排序\n",
    "\n",
    "#     return dists[:k]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "返回前k个邻居之后，这些邻居会有不同的标签，还需要投票选出出现次数最多的标签。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # 投票选出neighbors中出现次数最多的标签\n",
    "# from collections import Counter\n",
    "\n",
    "\n",
    "# def Vote(neighbors):\n",
    "#     counter = Counter()\n",
    "#     for idx in range(len(neighbors)):\n",
    "#         dist=neighbors[idx][1]\n",
    "#         counter[label] += 1/(dist+1)    # 首位(标签)计数，权重为距离的倒数\n",
    "#     return counter.most_common(1)[0][0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这样一来，就可以返回单个样本在训练数据库中最近邻邻居的标签了，即预测标签。接下来再定义一个有批量预测功能的函数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# def predict(X_test):\n",
    "#     Y_pred=[]\n",
    "#     for x_test in X_test:\n",
    "#         neighbors=get_nb_of_one(X_train,Y_train,x_test)\n",
    "#         Y_pred.append(Vote(neighbors))\n",
    "#     return np.array(Y_pred)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 封装\n",
    "基本功能都完成之后，把这些模块都封装起来，实现一个类SKlearn的KNN类。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import Counter\n",
    "\n",
    "\n",
    "class KNN:\n",
    "    def __init__(self, n_neighbors=5, metric=E_dist):\n",
    "        self.X_train = None\n",
    "        self.Y_train = None\n",
    "        self.k = n_neighbors\n",
    "        self.metric = metric\n",
    "\n",
    "    def fit(self, X_train, Y_train):\n",
    "        # 模型不改变输入数据，所以这里等号赋值没有问题\n",
    "        self.X_train = X_train\n",
    "        self.Y_train = Y_train\n",
    "\n",
    "    def __get_nb_of_one(self, x_test):\n",
    "        dists = []\n",
    "\n",
    "        for idx in range(len(self.X_train)):\n",
    "            cur_dist = self.metric(x_test, self.X_train[idx])\n",
    "            dists.append((self.Y_train[idx], cur_dist))    # 首位为标签，末位为距离\n",
    "        dists.sort(key=lambda x: x[1])    # 按照距离排序\n",
    "\n",
    "        return dists[:self.k]\n",
    "\n",
    "    def __vote(self, neighbors):\n",
    "        counter = Counter()\n",
    "        for idx in range(len(neighbors)):\n",
    "            dist = neighbors[idx][1]\n",
    "            label = neighbors[idx][0]\n",
    "            counter[label] += 1/(dist+1)    # 首位(标签)计数，权重为距离的倒数\n",
    "        return counter.most_common(1)[0][0]\n",
    "\n",
    "    def predict(self, X_test):\n",
    "        Y_pred = []\n",
    "        for x_test in X_test:\n",
    "            neighbors = self.__get_nb_of_one(x_test)\n",
    "            Y_pred.append(self.__vote(neighbors))\n",
    "        return np.array(Y_pred)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "acc:0.956140350877193\n"
     ]
    }
   ],
   "source": [
    "knn = KNN()\n",
    "knn.fit(X_train, Y_train)\n",
    "Y_pred = knn.predict(X_test)\n",
    "print('acc:{}'.format(np.sum(Y_pred == Y_test)/len(Y_test)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "完整代码中使用了KD树进行优化。"
   ]
  }
 ],
 "metadata": {
  "hide_input": false,
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
